{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 测试无琐碎绑定的重写"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.script import relax as R\n",
    "from tvm.script import tir as T\n",
    "from tvm import relax as rx\n",
    "from tvm import relay, tir\n",
    "from tvm.relax.analysis import get_var2val\n",
    "import tvm.testing\n",
    "from tvm.relax.dpl import *\n",
    "\n",
    "# Test parameter with two named cases: `var-to-var` (False) and\n",
    "# `var-to-dataflow-var` (True).\n",
    "bind_to_dataflow_var = tvm.testing.parameter(\n",
    "    by_dict={\n",
    "        \"var-to-var\": False,\n",
    "        \"var-to-dataflow-var\": True,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "```{note} rewrite_call 应避免生成琐碎的 \"y = x\" 绑定\n",
    "\n",
    "这并非在所有情况下都可行，具体遵循与 CanonicalizeBindings 相同的规则。例如，绑定到 `relax.DataflowVar` 的 `relax.Var` 可能无法被移除，以确保 `relax.DataflowVar` 仅在 `DataflowBlock` 内使用。\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def test_rewrite_without_trivial_binding(bind_to_dataflow_var):\n",
    "    \"\"\"Check that rewrite_call does not leave a trivial \"y = x\" binding behind.\n",
    "\n",
    "    Removing the binding is not always possible; the same rules as\n",
    "    CanonicalizeBindings apply.  For example, a `relax.Var` that is\n",
    "    bound to a `relax.DataflowVar` may have to be kept, so that the\n",
    "    `relax.DataflowVar` is only used within a `DataflowBlock`.\n",
    "\n",
    "    `bind_to_dataflow_var` selects the variant under test: True builds\n",
    "    the functions with a dataflow block, False builds them without one.\n",
    "    \"\"\"\n",
    "\n",
    "    if bind_to_dataflow_var:\n",
    "        # Variant with the bindings inside a dataflow block.\n",
    "\n",
    "        @R.function(private=True)\n",
    "        def before(x: R.Tensor((1024,))):\n",
    "            with R.dataflow():\n",
    "                a = R.add(x, x)\n",
    "                b = R.reshape(a, (1024,))\n",
    "                R.output(b)\n",
    "            return b\n",
    "\n",
    "        @R.function(private=True)\n",
    "        def expected(x: R.Tensor((1024,))):\n",
    "            with R.dataflow():\n",
    "                b = R.add(x, x)\n",
    "                R.output(b)\n",
    "            return b\n",
    "\n",
    "    else:\n",
    "        # Variant without a dataflow block.\n",
    "\n",
    "        @R.function(private=True)\n",
    "        def before(x: R.Tensor((1024,))):\n",
    "            a = R.add(x, x)\n",
    "            b = R.reshape(a, (1024,))\n",
    "            return b\n",
    "\n",
    "        @R.function(private=True)\n",
    "        def expected(x: R.Tensor((1024,))):\n",
    "            a = R.add(x, x)\n",
    "            return a\n",
    "\n",
    "    # Match any call to relax.reshape, capturing its two arguments.\n",
    "    arg_pat = wildcard()\n",
    "    shape_pat = wildcard()\n",
    "    reshape_pat = is_op(\"relax.reshape\")(arg_pat, shape_pat)\n",
    "\n",
    "    def drop_noop_reshape(expr, matches):\n",
    "        \"\"\"Return the reshape argument when its shape equals the target shape.\"\"\"\n",
    "        reshaped = matches[arg_pat]\n",
    "        target_shape = matches[shape_pat]\n",
    "\n",
    "        same_shape = tvm.ir.structural_equal(\n",
    "            reshaped.struct_info.shape, target_shape\n",
    "        )\n",
    "        # Otherwise hand back the matched expression unchanged.\n",
    "        return reshaped if same_shape else expr\n",
    "\n",
    "    after = rewrite_call(reshape_pat, drop_noop_reshape, before)\n",
    "    tvm.ir.assert_structural_equal(after, expected)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ai",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
